/*
 * Copyright (C) 2015 The Pennsylvania State University and the University of Wisconsin
 * Systems and Internet Infrastructure Security Laboratory
 *
 * Author: Damien Octeau
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package edu.psu.cse.siis.ic3.db;

import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

import edu.psu.cse.siis.coal.Constants;

/**
 * Database table mapping strings (column {@code st}) to integer ids (column {@code id}).
 *
 * <p>A {@code null} string is never stored as SQL NULL: every entry point maps it to
 * {@link Constants#NULL_STRING} before binding it to a statement.
 */
public class StringTable extends Table {
  private static final String INSERT = "INSERT INTO \"%s\" (st) VALUES (?)";
  private static final String FIND = "SELECT id FROM \"%s\" WHERE st = ?";
  private static final String BATCH_INSERT = INSERT;
  // Deliberately left unterminated: batchFind appends one ", ?" per extra element and then ")".
  private static final String BATCH_FIND = "SELECT id, st FROM \"%s\" WHERE st IN (?";

  /**
   * Creates a string table bound to the given database table.
   *
   * @param table name of the database table; it is formatted (inside double quotes) directly
   *     into the SQL text, so it must come from trusted code rather than external input
   */
  StringTable(String table) {
    insertString = String.format(INSERT, table);
    findString = String.format(FIND, table);
    batchInsertString = String.format(BATCH_INSERT, table);
    batchFindString = String.format(BATCH_FIND, table);
  }

  /**
   * Looks up the ids of several strings with a single {@code SELECT ... WHERE st IN (...)} query.
   *
   * <p>A {@code null} element is looked up as {@link Constants#NULL_STRING}. A {@code null} or
   * empty set, or a set whose only element is {@code null}, yields an empty map without touching
   * the database.
   *
   * @param strings the strings to look up; may be {@code null}
   * @return a map from each string found in the table, as stored (so a {@code null} input appears
   *     under {@link Constants#NULL_STRING}), to its id; strings not found are absent
   * @throws SQLException if preparing or executing the query fails
   */
  public Map<String, Integer> batchFind(Set<String> strings) throws SQLException {
    Map<String, Integer> result = new HashMap<String, Integer>();
    if (strings == null || strings.isEmpty() || (strings.size() == 1 && strings.contains(null))) {
      return result;
    }
    // One placeholder is already part of batchFindString; add one per remaining element.
    StringBuilder queryBuilder = new StringBuilder(batchFindString);
    for (int i = 1; i < strings.size(); ++i) {
      queryBuilder.append(", ?");
    }
    queryBuilder.append(")");

    // The statement is built per call (its arity varies), so it must be closed here.
    PreparedStatement batchFindStatement =
        getConnection().prepareStatement(queryBuilder.toString());
    try {
      int parameterIndex = 1;
      for (String string : strings) {
        batchFindStatement.setString(parameterIndex++, toStoredString(string));
      }
      ResultSet resultSet = batchFindStatement.executeQuery();
      try {
        while (resultSet.next()) {
          result.put(resultSet.getString("st"), resultSet.getInt("id"));
        }
      } finally {
        resultSet.close();
      }
    } finally {
      batchFindStatement.close();
    }
    return result;
  }

  /**
   * Ensures every string in {@code strings} is present in the table and returns their ids.
   *
   * <p>Existing rows are found with {@link #batchFind}. The remaining strings are inserted with
   * a single multi-row {@code INSERT}, and their ids are read from the generated keys. A
   * {@code null} element is stored as {@link Constants#NULL_STRING}. As in {@link #batchFind}, a
   * {@code null} set or a set whose only element is {@code null} is a no-op returning an empty
   * set.
   *
   * <p>NOTE(review): the lookup and the insert are separate statements. They are not atomic
   * unless the caller wraps them in a transaction.
   *
   * @param strings the strings to look up or insert; may be {@code null}
   * @param allThere optional out-parameter; when non-null and non-empty, {@code allThere[0]} is
   *     set to {@code true} if nothing had to be inserted and to {@code false} otherwise
   * @return the ids of all requested strings (existing and newly inserted)
   * @throws SQLException if a lookup or the insert fails
   */
  public Set<Integer> batchInsert(Set<String> strings, boolean[] allThere) throws SQLException {
    if (strings == null || (strings.size() == 1 && strings.contains(null))) {
      if (allThere != null && allThere.length > 0) {
        allThere[0] = true;
      }
      return new HashSet<Integer>();
    }

    // Normalize null to the stored sentinel first. batchFind keys its result by the stored
    // value, so an un-normalized null would never be removed from toBeInserted and a duplicate
    // sentinel row would be inserted on every call.
    Set<String> normalized = new HashSet<String>();
    for (String string : strings) {
      normalized.add(toStoredString(string));
    }

    Map<String, Integer> alreadyThere = batchFind(normalized);
    Set<String> toBeInserted = new HashSet<String>(normalized);
    toBeInserted.removeAll(alreadyThere.keySet());
    Set<Integer> result = new HashSet<Integer>(alreadyThere.values());

    boolean nothingToInsert = toBeInserted.isEmpty();
    if (allThere != null && allThere.length > 0) {
      allThere[0] = nothingToInsert;
    }
    if (nothingToInsert) {
      return result;
    }

    // batchInsertString already holds one "(?)" row; add one per remaining string.
    StringBuilder queryBuilder = new StringBuilder(batchInsertString);
    for (int i = 1; i < toBeInserted.size(); ++i) {
      queryBuilder.append(", (?)");
    }

    PreparedStatement batchInsertStatement =
        getConnection().prepareStatement(queryBuilder.toString(), AUTOGENERATED_ID);
    try {
      int parameterIndex = 1;
      for (String string : toBeInserted) {
        batchInsertStatement.setString(parameterIndex++, string);
      }
      batchInsertStatement.executeUpdate();
      ResultSet resultSet = batchInsertStatement.getGeneratedKeys();
      try {
        while (resultSet.next()) {
          result.add(resultSet.getInt(1));
        }
      } finally {
        resultSet.close();
      }
    } finally {
      batchInsertStatement.close();
    }
    return result;
  }

  /**
   * Returns the id of {@code st}, inserting it first if it is not already in the table.
   *
   * @param st the string to look up or insert; {@code null} is stored as
   *     {@link Constants#NULL_STRING}
   * @return the id of the existing or newly inserted row
   * @throws SQLException if the lookup or the insert fails
   */
  public int insert(String st) throws SQLException {
    String stored = toStoredString(st);
    int id = find(stored);
    if (id != NOT_FOUND) {
      return id;
    }
    return forceInsert(stored);
  }

  /**
   * Inserts {@code st} unconditionally, without checking whether it already exists.
   *
   * <p>The insert statement is prepared lazily, re-prepared if it has been closed, and reused
   * across calls. It appends a {@code RETURNING id} clause so the new id comes back in a result
   * set.
   *
   * @param st the string to insert; {@code null} is stored as {@link Constants#NULL_STRING}
   * @return the id produced by the insert, as extracted by {@code processIntFindQuery}
   * @throws SQLException if preparing or executing the insert fails
   */
  public int forceInsert(String st) throws SQLException {
    if (insertStatement == null || insertStatement.isClosed()) {
      insertStatement = getConnection().prepareStatement(insertString + " returning id;");
    }
    insertStatement.setString(1, toStoredString(st));
    return processIntFindQuery(insertStatement);
  }

  /**
   * Looks up the id of a single string.
   *
   * <p>The find statement is prepared lazily, re-prepared if it has been closed, and reused
   * across calls.
   *
   * @param st the string to look up; {@code null} is looked up as {@link Constants#NULL_STRING}
   * @return the id of the matching row, or presumably {@code NOT_FOUND} when no row matches
   *     (as {@link #insert} assumes); the exact contract is that of {@code processIntFindQuery}
   * @throws SQLException if preparing or executing the query fails
   */
  public int find(String st) throws SQLException {
    if (findStatement == null || findStatement.isClosed()) {
      findStatement = getConnection().prepareStatement(findString);
    }
    findStatement.setString(1, toStoredString(st));
    return processIntFindQuery(findStatement);
  }

  /** Maps {@code null} to {@link Constants#NULL_STRING}, the value stored for a null string. */
  private static String toStoredString(String st) {
    return st == null ? Constants.NULL_STRING : st;
  }
}
